import copy
import functools
import json
import logging
from pathlib import Path
from typing import Callable, Tuple

import hydra
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

from diffusion_bandit import utils
from diffusion_bandit.conditional_generation import ScoreRewardModel
from diffusion_bandit.dataset_generation import distance_to_sphere
from diffusion_bandit.diffusion import DiffusionProcess
from diffusion_bandit.linear_ts_plotting import plot_bandit_diffusion
from diffusion_bandit.linear_ts_utils import initialize_parameters, update_posterior
from diffusion_bandit.neural_networks.shape_reward_nets import (
    get_ground_truth_reward_model,
)
from diffusion_bandit.samplers import Sampler, get_sampler

# Configure logging: set the root level to INFO and create the module-level
# logger used by every function in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@hydra.main(
    version_base=None, config_path="configs", config_name="linear_thompson_sampling"
)
def main(config: DictConfig) -> None:
    """
    Run the linear Thompson sampling experiment described by ``config``.

    The function loads the saved score-model checkpoint from
    ``config.outputs_dir``, builds the reward model and the sampler, executes
    ``config.thompson.num_runs`` independent calls to ``thompson_sampling``,
    plots the collected results and finally stores them as
    ``all_results_diff.pth`` inside ``config.outputs_dir``.

    Args:
        config: Hydra/OmegaConf configuration of the experiment.
    """
    factor = config.misspec.factor

    logger.info(OmegaConf.to_yaml(config))
    utils.seeding.seed_everything(config)

    # Load the checkpoint that holds the score model and its dataset metadata.
    checkpoint_path = Path(config.outputs_dir) / f"{config.names.score_model}.pth"
    checkpoint = torch.load(checkpoint_path, weights_only=False)
    device = torch.device(config.sampler.device)

    # Dataset geometry stored alongside the model.
    dataset_config = checkpoint["dataset_config"]["dataset"]
    d_ext = dataset_config["d_ext"]
    radius = dataset_config["radius"]
    surface = dataset_config["surface"]

    # Objects needed to rebuild the diffusion pipeline.
    projector = checkpoint["projector"]
    x_data = checkpoint["x_data"]
    score_model = checkpoint["score_model"]
    beta_min = checkpoint["beta_min"]
    beta_max = checkpoint["beta_max"]

    diffusion_process = DiffusionProcess(beta_min=beta_min, beta_max=beta_max)

    # The radius is rescaled by the misspecification factor wherever the
    # manifold geometry is handed to a downstream component.
    scaled_radius = radius * factor

    reward_model = get_ground_truth_reward_model(
        d_ext=d_ext,
        projector=projector,
        radius=scaled_radius,
        surface=surface,
        name=config.reward_model.name,
        diffusion_process=diffusion_process,
        score=copy.deepcopy(score_model),
    )

    # Move both networks to the target device and fix their train/eval state.
    score_model.to(device)
    reward_model.to(device)
    reward_model.train()
    score_model.eval()
    # Tag every reward-model parameter with a marker attribute.
    for param in reward_model.parameters():
        param.reward_model = True

    sampler = get_sampler(
        name=config.names.sampler,
        shape=[d_ext],
        diffusion_process=diffusion_process,
        alpha_max=config.conditional.alpha_max,
        **config.sampler,
    )

    x_data = torch.FloatTensor(x_data).to(device)
    projector = projector.to(device)  # keep the projector on the same device

    manifold_distance_fn = functools.partial(
        distance_to_sphere,
        projector=projector,
        radius=scaled_radius,
        surface=surface,
    )

    num_runs = config.thompson.num_runs
    all_results = []
    for run_idx in range(num_runs):
        logger.info(f"Starting run {run_idx + 1}/{num_runs}")
        mean, cov, theta_gt = initialize_parameters(
            d_ext, config.thompson.prior_var, device
        )

        # Each run gets its own ground-truth model carrying the freshly drawn
        # parameter vector.
        reward_model_gt = copy.deepcopy(reward_model)
        reward_model_gt.layer.weight.data = theta_gt

        run_results = thompson_sampling(
            config=config,
            x_data=x_data,
            sampler=sampler,
            score_model=score_model,
            reward_model=reward_model_gt,
            mean=mean,
            cov=cov,
            manifold_distance_fn=manifold_distance_fn,
            device=device,
            projector=projector,
            radius=radius,
            factor=factor,
        )
        all_results.append(run_results)

    plot_bandit_diffusion(all_results, projector)

    results_save_path = Path(config.outputs_dir) / "all_results_diff.pth"
    torch.save(all_results, results_save_path)
    logger.info(f"All results saved to {results_save_path}")


def thompson_sampling(
    config: DictConfig,
    x_data: torch.Tensor,
    sampler: Sampler,
    score_model: torch.nn.Module,
    reward_model: torch.nn.Module,
    mean: torch.Tensor,
    cov: torch.Tensor,
    manifold_distance_fn: Callable,
    device: torch.device,
    projector: torch.Tensor,
    radius: float,
    factor: float,
) -> dict:
    """
    Perform Thompson sampling for optimization.

    Every accepted iteration
      1. samples a parameter vector from ``MultivariateNormal(mean, cov)`` and
         loads it into a deep copy of ``reward_model``,
      2. asks the oracle selected by ``config.oracle.mode`` for a batch of
         candidate points,
      3. picks one candidate: the highest sampled-model reward among those whose
         manifold distance is below ``config.oracle.allowed_dist``, or the
         closest candidate when none qualifies,
      4. queries ``reward_model`` (switched to "feedback" mode) at that point,
         adds Gaussian noise with variance ``config.thompson.noise_var`` and
      5. updates the posterior with ``update_posterior``.

    Iterations whose posterior update contains NaNs are discarded and retried.

    Args:
        config: Experiment configuration (``thompson``, ``oracle``, ... groups).
        x_data: Dataset points used to compute the dataset-maximal reward.
        sampler: Sampler handed to the oracle.
        score_model: Score network handed to the oracle.
        reward_model: Ground-truth reward model; ``layer.weight`` is read as the
            ground-truth parameter. Its mode is changed to "feedback" here.
        mean: Initial posterior mean.
        cov: Initial posterior covariance.
        manifold_distance_fn: Maps a batch of points to their manifold distance.
        device: Device passed to the oracle and to ``update_posterior``.
        projector: Projector used for the "max obtainable" reward computation.
        radius: Manifold radius (before scaling by ``factor``).
        factor: Misspecification factor multiplied onto ``radius``.

    Returns:
        Dictionary with per-iteration lists (``rewards_gt``, ``rewards_gt_noisy``,
        ``rewards_iterate``, ``condition``, ``distances``, ``posterior_mean``,
        ``posterior_cov``, ``theta_iterate``) and the scalars/arrays
        ``theta_gt``, ``max_obtainable`` and ``max_obtainable_data``.

    Raises:
        NotImplementedError: If ``config.oracle.mode`` is neither "binary" nor
            "data_max".
    """

    noise_var = config.thompson.noise_var
    thompson_iterates = config.thompson.thompson_iterates

    theta_gt = reward_model.layer.weight.data.clone().detach()
    max_obtainable_omega = factor * radius * torch.linalg.norm(theta_gt @ projector)
    with torch.no_grad():
        max_obtainable_data = torch.max(reward_model(x_data))

    results = {
        "rewards_gt": [],
        "rewards_gt_noisy": [],
        "rewards_iterate": [],
        "condition": [],
        "distances": [],
        "posterior_mean": [],
        "posterior_cov": [],
        "theta_iterate": [],
        "theta_gt": theta_gt.clone().cpu().numpy(),
        "max_obtainable": max_obtainable_omega.clone().cpu().numpy(),
        "max_obtainable_data": max_obtainable_data.clone().cpu().numpy(),
    }

    reward_model_iterate = copy.deepcopy(reward_model)
    reward_model_iterate.set_mode(mode="simple")

    idx = 0
    # Running mean (over accepted iterations) of the diagnostic distance below.
    mean_dist = 0
    while idx < thompson_iterates:
        theta_iterate = torch.distributions.MultivariateNormal(mean, cov).sample()
        reward_model_iterate.layer.weight.data = theta_iterate.clone().detach()

        # Diagnostic only: same formula as max_obtainable_omega, for the sample.
        max_obtainable_iteration = (
            factor * radius * torch.linalg.norm(theta_iterate @ projector)
        )
        logger.debug(
            "Max obtainable reward for sampled theta: %s", max_obtainable_iteration
        )

        # The model is in "simple" mode here (set before the loop and again at
        # the end of every iteration).
        with torch.no_grad():
            dataset_max = torch.max(reward_model_iterate(x_data)).item()

        reward_model_iterate.set_mode(mode="sampling")

        if config.oracle.mode == "binary":
            condition, final_path = binary_search_oracle(
                sampler=sampler,
                score_model=score_model,
                reward_model_iterate=reward_model_iterate,
                manifold_distance_fn=manifold_distance_fn,
                config=config,
                device=device,
                lower_bound=dataset_max,
            )
        elif config.oracle.mode == "data_max":
            condition, final_path = dataset_maximizer(
                reward_model_iterate=reward_model_iterate,
                max_reward=dataset_max,
                sampler=sampler,
                score_model=score_model,
                config=config,
                device=device,
            )
        else:
            raise NotImplementedError(
                f"Oracle mode '{config.oracle.mode}' is not implemented."
            )

        with torch.no_grad():
            # Evaluate the candidates with the sampled model in "simple" mode.
            reward_model_iterate.set_mode(mode="simple")
            rewards_iterate = reward_model_iterate(final_path)
            distances = manifold_distance_fn(final_path)

            mask = distances < config.oracle.allowed_dist

            if torch.any(mask):
                # Best sampled-model reward among the admissible candidates,
                # mapped back to an index into the full batch.
                filtered_rewards = rewards_iterate[mask]
                mask_indices = torch.nonzero(mask, as_tuple=False).squeeze(1)
                max_filtered_reward_idx = torch.argmax(filtered_rewards)
                selected_idx = mask_indices[max_filtered_reward_idx]
            else:
                # No candidate satisfies the constraint: fall back to the one
                # closest to the manifold.
                selected_idx = torch.argmin(distances)
            selected_point = final_path[selected_idx].unsqueeze(0)

            reward_model.set_mode(mode="feedback")
            rewards_gt_selected = reward_model(selected_point)

            # Diagnostic only: distance of the candidate whose ground-truth
            # reward is closest to the requested condition. The running mean is
            # computed as a candidate and committed only once this iteration is
            # accepted; otherwise a rejected attempt (NaN guard below, idx not
            # advanced) would be mixed into the mean with the wrong weight.
            gt = reward_model(final_path)
            min_deviation_idx = torch.argmin(torch.abs(gt - condition))
            candidate_mean_dist = (mean_dist * idx + distances[min_deviation_idx]) / (
                idx + 1
            )
            logger.info(
                "Mean distance in iter %d: %s", idx, candidate_mean_dist.item()
            )

            # Reduce everything to the single selected candidate.
            rewards_gt = rewards_gt_selected
            rewards_iterate = rewards_iterate[selected_idx].unsqueeze(0)
            condition = condition[selected_idx].unsqueeze(0)
            distances = distances[selected_idx].unsqueeze(0)
            x_new = selected_point

        noisy_rewards = rewards_gt + torch.randn_like(rewards_gt) * (noise_var**0.5)

        # Update posterior
        new_mean, new_cov = update_posterior(
            mean=mean.clone().detach(),
            cov=cov.clone().detach(),
            x_new=x_new,
            y_new=noisy_rewards,
            noise_var=noise_var,
            device=device,
        )

        if torch.isnan(new_mean).any() or torch.isnan(new_cov).any():
            # NOTE(review): the attempt is retried without a retry limit; a
            # persistent NaN source would make this loop spin forever.
            logger.warning(
                "NaN detected in mean or covariance matrix. Ignoring this sample."
            )
            continue

        mean = new_mean
        cov = new_cov
        mean_dist = candidate_mean_dist

        # Store results
        results["rewards_gt"].append(rewards_gt.cpu().numpy())
        results["rewards_gt_noisy"].append(noisy_rewards.cpu().numpy())
        results["rewards_iterate"].append(rewards_iterate.cpu().numpy())
        results["condition"].append(condition.cpu().numpy())
        results["distances"].append(distances.cpu().numpy())
        results["posterior_mean"].append(mean.cpu().numpy())
        results["posterior_cov"].append(cov.cpu().numpy())
        results["theta_iterate"].append(theta_iterate.cpu().numpy())

        idx += 1

    return results


def _sample_conditioned_batch(
    sampler: Sampler,
    score_model: torch.nn.Module,
    reward_model_iterate: torch.nn.Module,
    condition: torch.Tensor,
    config: DictConfig,
    device: torch.device,
) -> torch.Tensor:
    """Draw one batch guided towards ``condition``; return the last path element on ``device``."""
    # Every sampling call in this file runs with the reward model in "sampling"
    # mode, while evaluation switches it to "simple". Setting the mode here
    # guarantees that no draw accidentally runs in the evaluation mode.
    reward_model_iterate.set_mode(mode="sampling")
    score_reward_model_fn = ScoreRewardModel(
        condition=condition,
        reward_model=reward_model_iterate,
        score_model=score_model,
        variance=config.conditional.variance,
    )
    paths = sampler.sample(
        score_model=score_reward_model_fn,
        batch_size=config.oracle.num_generations,
        num_steps=config.sampler.num_steps,
        eps=config.sampler.eps,
    )
    return paths[-1].to(device)


def binary_search_oracle(
    sampler: Sampler,
    score_model: torch.nn.Module,
    reward_model_iterate: torch.nn.Module,
    manifold_distance_fn: Callable,
    config: DictConfig,
    device: torch.device,
    lower_bound: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Perform a binary search to find the appropriate condition threshold.

    The search runs on the interval ``[lower_bound, config.oracle.upper_bound]``
    for at most ``config.oracle.max_iters`` iterations. A midpoint counts as
    attainable when at least one generated sample has a manifold distance below
    ``config.oracle.dist_thresh`` and a reward (under ``reward_model_iterate``
    in "simple" mode) strictly above the midpoint; attainable midpoints raise
    the lower bound, all others lower the upper bound. The search stops early
    once the interval is narrower than ``config.oracle.interval_thresh``.

    After the search a fresh batch is generated at the final lower bound.

    Args:
        sampler: Sampler used to generate candidate points.
        score_model: Score network combined with the reward model for guidance.
        reward_model_iterate: Reward model of the current Thompson iterate. Its
            mode is changed by this function and is "sampling" on return.
        manifold_distance_fn: Maps a batch of points to their manifold distance.
        config: Experiment configuration (``oracle``, ``sampler``,
            ``conditional`` groups are read).
        device: Device on which conditions and samples are placed.
        lower_bound: Initial lower end of the search interval.

    Returns:
        ``(final_condition, final_path)``: the condition tensor of length
        ``config.oracle.num_generations`` filled with the final lower bound, and
        the batch generated for it.
    """
    upper_bound = config.oracle.upper_bound
    for oracle_iter in range(config.oracle.max_iters):
        mid_point = lower_bound + (upper_bound - lower_bound) / 2
        logger.debug(
            "Binary search bounds: lower=%s mid=%s upper=%s",
            lower_bound,
            mid_point,
            upper_bound,
        )
        mid_point_t = torch.tensor(mid_point, device=device)
        condition = mid_point_t * torch.ones(
            (config.oracle.num_generations,), device=device
        )

        final_path = _sample_conditioned_batch(
            sampler=sampler,
            score_model=score_model,
            reward_model_iterate=reward_model_iterate,
            condition=condition,
            config=config,
            device=device,
        )

        with torch.no_grad():
            reward_model_iterate.set_mode(mode="simple")
            mid_point_rewards = reward_model_iterate(final_path).squeeze(-1)
            distances = manifold_distance_fn(final_path)

        # Determine if any sample meets the condition
        condition_met = torch.logical_and(
            distances < config.oracle.dist_thresh, mid_point_rewards > mid_point_t
        )

        logger.debug("reward difference %s", mid_point_rewards - mid_point_t)
        logger.debug("distances %s", distances)
        logger.debug("condition %s", condition_met)

        if condition_met.any():
            # If condition is met, increase the lower bound
            lower_bound = mid_point
        else:
            # If condition is not met, decrease the upper bound
            upper_bound = mid_point

        # Check if the interval is within the desired threshold
        if (upper_bound - lower_bound) < config.oracle.interval_thresh:
            logger.info(f"Binary search converged in {oracle_iter + 1} iterations.")
            logger.info(f"Final values: {lower_bound, upper_bound}")
            break
    else:
        logger.warning(
            "Binary search did not converge within the maximum number of iterations."
        )

    # Final condition and sampling after binary search. The loop leaves the
    # reward model in "simple" mode; the helper restores "sampling" mode so the
    # returned batch is generated the same way as the batches inside the loop.
    final_condition = torch.tensor(lower_bound, device=device) * torch.ones(
        (config.oracle.num_generations,), device=device
    )
    final_path = _sample_conditioned_batch(
        sampler=sampler,
        score_model=score_model,
        reward_model_iterate=reward_model_iterate,
        condition=final_condition,
        config=config,
        device=device,
    )

    return final_condition, final_path


def dataset_maximizer(
    reward_model_iterate: torch.nn.Module,
    max_reward: float,
    sampler: Sampler,
    score_model: torch.nn.Module,
    config: DictConfig,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate a batch of samples conditioned on the dataset's maximum reward.

    Args:
        reward_model_iterate: Reward model of the current iterate; switched to
            "sampling" mode when it offers ``set_mode``.
        max_reward: Reward value every generated sample is conditioned on.
        sampler: Sampler used to generate the paths.
        score_model: Score network combined with the reward model.
        config: Experiment configuration (``oracle``, ``sampler``,
            ``conditional`` groups are read).
        device: Device on which the condition and the samples are placed.

    Returns:
        ``(condition, final_path)``: the condition tensor of length
        ``config.oracle.num_generations`` and the last element of the sampled
        paths moved to ``device``.
    """
    # Broadcast the scalar target over the whole generation batch.
    target = torch.tensor(max_reward, device=device)
    num_generations = config.oracle.num_generations
    condition = target * torch.ones((num_generations,), device=device)

    # Models without a mode switch are used as they are.
    if hasattr(reward_model_iterate, "set_mode"):
        reward_model_iterate.set_mode(mode="sampling")

    guided_score_fn = ScoreRewardModel(
        condition=condition,
        reward_model=reward_model_iterate,
        score_model=score_model,
        variance=config.conditional.variance,
    )

    sampled_paths = sampler.sample(
        score_model=guided_score_fn,
        batch_size=num_generations,
        num_steps=config.sampler.num_steps,
        eps=config.sampler.eps,
    )

    # Only the last element of the returned paths is handed back.
    return condition, sampled_paths[-1].to(device)


# Script entry point. main() is called without arguments because it is wrapped
# by the @hydra.main decorator; the pylint directive silences the resulting
# missing-argument warning.
if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
